#! /usr/bin/env python
## This file is part of biopy.
## Copyright (C) 2013 Joseph Heled
## Author: Joseph Heled <jheled@gmail.com>
## See the files gpl.txt and lgpl.txt for copying conditions.

from __future__ import division

import argparse, sys, os.path

import Bio.SubsMat.MatrixInfo
from Bio.Data import CodonTable

# Command line definition. Help strings are re-wrapped by argparse, so their
# layout in the source does not matter.

# Names of the substitution matrices known to Biopython, listed in the epilog.
availScores = ','.join(Bio.SubsMat.MatrixInfo.available_matrices)
#class MyParser(argparse.ArgumentDefaultsHelpFormatter):
#  def format_epilog(self, formatter):
#      return self.epilog

parser = argparse.ArgumentParser(
  description = "Correct a short coding DNA read using an Amino Acid reference.",
  formatter_class = argparse.ArgumentDefaultsHelpFormatter,
  # Sorted so that the help text does not depend on dictionary ordering.
  epilog = "Available match scores: " + availScores +
           '.\n Available Genetic Codes: ' +
           ','.join(sorted(CodonTable.unambiguous_dna_by_name)))

parser.add_argument("--output", choices = ["fancy", "plain", "sequence", "full-sequence"],
                    default = "plain",
                    help = "Output options. 'plain' outputs the corrected nucleotide"
                    " fragment and the corresponding AA reference on one line. 'sequence'"
                    " outputs only the nucleotide sequence, and 'full-sequence' outputs the"
                    " unmatched prefix and suffix as well (in lowercase). 'fancy' prints a"
                    " detailed, human readable description.")

parser.add_argument("-t", "--translation", default = "Standard",
                    help = "Name of Genetic Code translation table."
                    " NCBI table numbers are supported as well.")

# The empty metavar keeps the (long) list of choices out of the usage line.
parser.add_argument("-m", "--match-scores", dest = "mscores", default = "blosum80",
                    metavar = '', choices = Bio.SubsMat.MatrixInfo.available_matrices,
                    help = "Amino Acid match scores.")

parser.add_argument("-c", "--correction-scores", dest = "cscores", default = "-10,-10,-100",
                    help = "Comma separated correction Penalties: Indel,Correction,StopCodon.")

# Passed on as the 'lim' argument of the database lookup (method 'database').
parser.add_argument("--ncandidates", default = 30, type = int, metavar = 'N',
                    help = "Limit on the number of candidate references considered when"
                    " searching for a database match (used with --method database).")

parser.add_argument("--method", choices = ["denovo","database"], default = "database",
                    help = "denovo: generate a single reference from the DNA sequences in the"
                    " references file (first argument). database: pick the closest sequence"
                    " from the (DNA) references file (using blast-like heuristics).")

parser.add_argument("--search-scores", dest = "dnaScores", metavar = "M,X,G,E,F",
                    default = "10,-5,-6,-6,1",
                    help = "Alignment scoring parameters to use when searching for a"
                    " database match: match,mismatch,gapopen,gapextend,free-ends.")

parser.add_argument('aaref', metavar = "PEPTIDE-OR-FILE",
                    help = "An amino acid reference sequence or a FASTA file.")

parser.add_argument('seqs', metavar = 'SEQUENCE-OR-FILE',
                    help = "A nucleotide sequence or a FASTA file containing reads to correct.")

options = parser.parse_args()

import sys

# Resolve the genetic code: an all-digit value is looked up as an NCBI table
# id, anything else as a table name.
try :
  if options.translation.isdigit() :
    transTable = CodonTable.unambiguous_dna_by_id[int(options.translation)]
  else :
    transTable = CodonTable.unambiguous_dna_by_name[options.translation]
except KeyError :
  print >> sys.stderr, "No such translation table:",options.translation
  sys.exit(1)

# The None default is what makes the check below reachable: without it a
# missing matrix name raises AttributeError instead of returning None.
matchScores = getattr(Bio.SubsMat.MatrixInfo, options.mscores, None)
if matchScores is None :
  print >> sys.stderr, "No such match scores table:", options.mscores
  sys.exit(1)

# Exactly three numbers are expected. Both a non-numeric field and a wrong
# number of fields raise ValueError.
try :
  indelPenalty,correctionPenalty,stopCodonPenalty = [float(x) for x in options.cscores.split(',')]
except ValueError :
  print >> sys.stderr, "Incorrect correction scores"
  sys.exit(1)
  

from biopy import align, calign, aalign
from biopy.bioutils import readFasta
from biopy.genericutils import fileFromName

# Codon lookup table: slot 16*n1 + 4*n2 + n3 (n1..n3 being the calign codes of
# the three nucleotides) holds the position in aalign.AAorder of the encoded
# amino acid, or -1 when the translation table has no entry for the codon.
# Assumes calign.A/G/C/T are the integers 0..3 -- the final assert verifies
# that all 64 slots were filled.
nucCodes = [(getattr(calign, nuc), nuc) for nuc in "AGCT"]

geneticCode = [None]*64
for code1,nuc1 in nucCodes :
  for code2,nuc2 in nucCodes :
    for code3,nuc3 in nucCodes :
      amino = transTable.forward_table.get(nuc1 + nuc2 + nuc3)
      slot = 16*code1 + 4*code2 + code3
      geneticCode[slot] = aalign.AAorder.index(amino) if amino is not None else -1

# Sanity check: every stop codon must have ended up untranslated.
for stop in transTable.stop_codons :
  code1,code2,code3 = [nucCodes["AGCT".index(nuc)][0] for nuc in stop]
  assert geneticCode[16*code1 + 4*code2 + code3] == -1

assert None not in geneticCode

def getAAref(s) :
  """Default reference lookup: the same global reference for every read.

  The read s is ignored. Returns the pair (aaref, None), where aaref is the
  module level reference and None stands for "no description".
  """
  description = None
  return (aaref, description)

# Establish the amino acid reference. Three cases:
#  - a references file with method 'denovo': one consensus reference (aaref)
#    is built from all sequences in the file,
#  - a references file with method 'database': getAAref is redefined to pick
#    a reference per read,
#  - otherwise the argument itself is the reference, given either as DNA
#    (translated here) or as a peptide.

# Number of aligned sequences behind a de novo reference, 0 if none was built.
deNovoAA = 0
if os.path.exists(options.aaref) :

  seqsForRef = list(readFasta(fileFromName(options.aaref)))
  if options.method == "denovo" :
    if len(seqsForRef) < 2 :
      print >> sys.stderr, "Invalid seqeunces file for AA reference: (",options.aaref,")"
      sys.exit(1)

    def doTheCons(sqs, lengthQuant = 40) :
      """Multiple-align sqs, ordered by decreasing length quantized to lengthQuant."""
      als = [(lengthQuant*(len(s)//lengthQuant), s) for s in sqs]
      als = [x[1] for x in sorted(als, reverse=1)]
      al = align.seqMultiAlign(als)
      return al

    al = doTheCons([x[1] for x in seqsForRef])
    aaCons,fstart = aalign.aacons(al, geneticCode)
    # Drop gaps from the consensus (from fstart on), then translate it codon
    # by codon through the geneticCode table.
    sq = filter(lambda x: x != calign.GAP, aaCons[fstart:])
    aaref = [geneticCode[16*sq[3*k] + 4*sq[3*k+1] + sq[3*k+2]] for k in range(len(sq)//3)]
    assert all([x >= 0 for x in aaref]), aaref
    aaref = ''.join([aalign.AAorder[x] for x in aaref])
    deNovoAA = len(al)

  else :
    from biopy import cclust, otus
    refs = [x[1] for x in seqsForRef]
    matches = cclust.lookupTable(refs, removeSingles=0, native=1)
    report = calign.DIVERGENCE

    try :
      scores = align.parseScores(options.dnaScores)
    except RuntimeError as e :
      print >> sys.stderr, "**Error in alignment scores:",e.message
      sys.exit(1)

    # Distance between a read and the j'th reference.
    fdis = lambda seq,j : calign.globalAlign(seq, refs[j], report = report, scores = scores)
    lseqs = [len(x) for x in refs]
    nMatchesLim = options.ncandidates
    ## if options.logrefs is not None :
    ##   logrefs = file(options.logrefs, 'w')
    ## else :
    ##   logrefs = None
    def getAAref(s) :
      """Pick the database reference for read s.

      Returns (aaref, description): the translation of the chosen reference
      and a tab separated "score, index, name" line describing it.
      """
      mc = otus.getMates(s, refs, lseqs, 1, fdis, matches, lim = nMatchesLim)
      assert mc[0]
      # Smallest (value, index) pair; presumably mc[1] holds distances and
      # mc[0] the matching reference indices -- verify against otus.getMates.
      c = sorted(zip(mc[1],mc[0]))[0]
      i = c[1]
      sq = refs[i]
      # NOTE(review): raises KeyError when a database sequence contains a
      # codon missing from the forward table (stop codon, ambiguity code,
      # lowercase letters).
      aaref = ''.join([transTable.forward_table[sq[3*k:3*k+3]] for k in range(len(sq)//3)])
      #global logrefs
      #if logrefs :
      #  print >> logrefs,
      return aaref,"%g\t%d\t%s" % (c[0], i, seqsForRef[i][0])
else :
  aaref = ''.join([x.upper() for x in options.aaref])
  if all([x in "AGCTN" for x in aaref]) :
    # assume DNA
    try :
      aaref = ''.join([transTable.forward_table[aaref[3*k:3*k+3]] for k in range(len(aaref)//3)])
    except KeyError as e :
      # 'N' passes the check above but has no translation, and neither do
      # stop codons. Report that instead of dying with a traceback.
      print >> sys.stderr, "Invalid DNA reference, codon can not be translated:",e.args[0]
      sys.exit(1)
  else :
    if not all([x in aalign.AAorder for x in aaref]) :
      print >> sys.stderr, "Invalid AA reference: (",options.aaref,")"
      sys.exit(1)
  
# Flatten the substitution matrix into a row-major list: the score of the
# amino acids at positions i and j of aalign.AAorder is msc[len(AAorder)*i + j].
# matchScores may store a pair in one orientation only, so the swapped pair
# is tried when the direct one is missing.
msc = []
for rowAA in aalign.AAorder :
  for colAA in aalign.AAorder :
    pair = (rowAA, colAA)
    msc.append(matchScores[pair] if pair in matchScores else matchScores[(colAA, rowAA)])

def showframes(a, seq) :
  aseq, aa, stats, frames = a
  aseq = align.iton(aseq)
  n = len(frames) ; assert 3*n == len(aseq)

  s = seq[stats["dnaFreeStart"]:-stats["dnaFreeEnd"]
          if stats["dnaFreeEnd"] else None]
  for rr in range((n+19)//20)  :
    r = list(range(20*rr,20*rr+20))
    print
    
    for k,i in enumerate(frames) :
      if k not in r : continue
      f = "%3s"
      if i > 3 :
        f = "%%%ds" % i
      print f % aseq[3*k:3*k+3],
    print

    l = 0
    for k,i in enumerate(frames) :
      if k in r : 

        if i > 3 :
          print '-xx-',
        elif 0 < i < 3 :
          print '+++',
        elif i == 3 :
          print ''.join([':' if x==y else '*' for x,y in zip(s[l:l+i], aseq[3*k:3*k+3])]),
        else :
          print '   ',
      l += i

    print    
    l = 0
    for k,i in enumerate(frames) :
      if k in r :
      
        f = "%3s"
        if i > 3 :
          f = "%%%ds" % i
        print f % s[l:l+i],
      l += i
    print

if os.path.exists(options.seqs) :

  if deNovoAA :
    print ";; De novo reference from",deNovoAA,"sequences:"
    print ";;", aaref
    print
    
  allSeqs = list(readFasta(fileFromName(options.seqs)))
  for nm,seq in allSeqs:
    aar,desc = getAAref(seq)
    res = aalign.acorrect(seq, aar, msc, geneticCode, indel = indelPenalty,
                          correction = correctionPenalty, stopCodon = stopCodonPenalty)

    if desc :
      print ';; ',desc
      
    stats = res[2]
    print ';;',stats["matches"],"matches,",stats["mismatches"],"mimatches,",stats["gaps"],"gaps."
    ci,cd = stats["correctionInsertions"], stats["correctionDeletions"]
    print ';;',ci+cd,"corrections:",ci,"insersions +",cd,"deletions."
    
    aa = ''.join([aalign.AAorder[x] if x < len(aalign.AAorder) else '-' for x in res[1]])
    print nm + '|' + aa
      
    s = align.iton(res[0])
    if options.output != "sequence" :
      stats = res[2]
      p = stats["dnaFreeStart"]
      dnapre = ''.join([x.lower() for x in seq[:p]]) if p else ''
      p = stats["dnaFreeEnd"]
      dnasuf = ''.join([x.lower() for x in seq[-p:]]) if p else ''
      s = dnapre + s + dnasuf
    print s
    print
    
else :
  seq = options.seqs
  if not all([x in "AGCTNagctn-" for x in seq]) :
    print >> sys.stderr, "Expecting a sequence or a fasta file name"
    sys.exit(1)

  aar,desc = getAAref(seq)
  res = aalign.acorrect(seq, aar, msc, geneticCode, indel = indelPenalty,
                        correction = correctionPenalty, stopCodon = stopCodonPenalty)
  
  s = align.iton(res[0])
  aa = ''.join([aalign.AAorder[x] if x < len(aalign.AAorder) else '-' for x in res[1]])
  if options.output == "plain" :
    print s, aa
  elif options.output == "sequence" :
    print s
  elif options.output == "full-sequence" :
    stats = res[2]
    p = stats["dnaFreeStart"]
    dnapre = ''.join([x.lower() for x in seq[:p]]) if p else ''
    p = stats["dnaFreeEnd"]
    dnasuf = ''.join([x.lower() for x in seq[-p:]]) if p else ''
    print dnapre + s + dnasuf
  elif options.output == "fancy" :
    if deNovoAA :
      print "De novo reference from",deNovoAA,"sequences:"
      print aar
    
    stats = res[2]
    print stats["matches"],"matches,",stats["mismatches"],"mimatches,",stats["gaps"],"gaps."
    ci,cd = stats["correctionInsertions"], stats["correctionDeletions"]

    print ci+cd,"corrections:",ci,"insersions +",cd,"deletions."
    dnapre = ''.join([x.lower() for x in seq[:stats["dnaFreeStart"]]])
    e = stats["dnaFreeEnd"]
    dnasuf = ''.join([x.lower() for x in seq[-e:]]) if e > 0 else ''
    aapre = ''.join([x.lower() for x in aar[:stats["aaFreeStart"]]])
    e = stats["aaFreeEnd"]
    aasuf = ''.join([x.lower() for x in aar[-e:]]) if e > 0 else ''
    lpre = max(len(dnapre),len(aapre))
    codons = [s[3*k:3*k+3] for k in range(len(s)//3)]
    tr = [transTable.forward_table[x] if x != '---' else '-' for x in codons]
    print ' '*lpre, ' '.join([' %c ' % x for x in tr])
    print ' '*(lpre-len(dnapre)) + dnapre, ' '.join(codons),dnasuf
    print ' '*(lpre-len(aapre)) + aapre,' '.join([' %c ' % x for x in aa]),aasuf

    showframes(res, seq)

#if options.logrefs is not None :
#  logrefs.close()

